# Remove all (non-hidden) objects from the current environment so that
# leftovers from a previous interactive run cannot influence this run.
rm(list = ls())

##### Modeling the restricted mean survival time with pseudo-value regression forests #####
# This script performs a single run of the main simulation study described in the manuscript.

library(simstudy)
library(survival)
library(riskRegression)
library(tidyverse)
library(magrittr)
library(pseudo)
library(ranger)
library(partykit)
library(batchtools)
library(survRM2)

source("functions.R") # load required functions

scenario   <- "Scenario1" # Scenario (choose "Scenario1", "Scenario2", "Scenario3", "Scenario4")
n          <- 1000        # number of individuals in the data set
p          <- 0.5         # p-th quantile of the observed times is chosen as time horizon
i          <- 1           # Number of simulation run (defines the seed, see below)
cens       <- 0.25        # censoring proportion

# Transition time of the time-varying treatment effect.
# - Scenarios 1 and 2: not used (NULL).
# - Scenarios 3 and 4: the pre-computed 70%-quantile of the observed times,
#   tabulated per censoring proportion below.
# Combinations that are not tabulated stop immediately with a clear message
# instead of leaving `transition` undefined (which only failed much later).
transition_cens_levels <- c(0.25, 0.5, 0.75)
transition_lookup <- list(
  Scenario3 = c(2.00, 1.04, 0.55),
  Scenario4 = c(2.00, 1.01, 0.37)
)

if (scenario %in% c("Scenario1", "Scenario2")) {
  transition <- NULL
} else if (scenario %in% names(transition_lookup)) {
  # tolerance-based match instead of `==` on doubles
  cens_idx <- which(abs(transition_cens_levels - cens) < 1e-8)
  if (length(cens_idx) != 1L) {
    stop("No transition time tabulated for ", scenario, " with cens = ", cens,
         "; supported values are ", paste(transition_cens_levels, collapse = ", "),
         call. = FALSE)
  }
  transition <- transition_lookup[[scenario]][cens_idx]
} else {
  stop("Unknown scenario '", scenario,
       "'; choose one of Scenario1, Scenario2, Scenario3, Scenario4.",
       call. = FALSE)
}


# Generate the regression coefficients. The fixed seed keeps them identical
# across all simulation runs. Scenarios 1 and 3 draw 11 coefficients,
# Scenarios 2 and 4 draw 19 (the extra 8 belong to the interaction terms).
set.seed(1234)
n_beta <- switch(as.character(scenario),
                 "Scenario1" = 11,
                 "Scenario2" = 19,
                 "Scenario3" = 11,
                 "Scenario4" = 19)
beta_1 <- round(runif(n_beta, min = -1, max = 1), 1)

# Position 11 corresponds to `trt` in the covariate order used below
# (x1-x5, X1-X5, trt). beta_2 equals beta_1 except at that position, where
# Scenarios 3 and 4 flip the sign (+2 instead of -2).
beta_1[11] <- -2

beta_2 <- beta_1
beta_2[11] <- if (scenario %in% c("Scenario3", "Scenario4")) +2 else -2

# Interaction terms of the true model (formula notation); only Scenarios 2 and
# 4 contain them. Lower-case x* and upper-case X* are different covariates.
interaction_terms <- c("x1*x3", "x1*x4", "x2*x3", "x2*x5", "x4*x5",
                       "x1*X2", "x2*X3", "x3*X4")

interactions <- switch(as.character(scenario),
                       "Scenario1" = NA,
                       "Scenario2" = interaction_terms,
                       "Scenario3" = NA,
                       "Scenario4" = interaction_terms)

# Determine the overall time horizon tau: the p-th quantile of the observed
# times, averaged over n_tau_reps simulated data sets. The fixed seed makes tau
# identical across all simulation runs.
# Results are collected with vapply (preallocated, type-checked) instead of
# growing a vector inside a loop.
set.seed(1234)
n_tau_reps <- 100
tau_vec <- vapply(seq_len(n_tau_reps), function(rep_id) {
  d <- switch(as.character(scenario),
              "Scenario1"        = simulate_data_scenario1(n, beta = beta_1, cens = cens),
              "Scenario2"        = simulate_data_scenario2(n, beta = beta_1, interactions = interactions, cens = cens),
              "Scenario3"        = simulate_data_scenario3(n, beta_1 = beta_1, beta_2 = beta_2, transition = transition, cens = cens),
              "Scenario4"        = simulate_data_scenario4(n, beta_1 = beta_1, beta_2 = beta_2,
                                                           interactions = interactions, transition = transition, cens = cens))
  unname(quantile(d$time, p))
}, numeric(1))

tau  <- mean(tau_vec)

# Simulate the analysis data set of this run. The seed depends on the run
# number i, so every run gets its own reproducible data set.
set.seed(12345*i)
sim.data <- if (scenario == "Scenario1") {
  simulate_data_scenario1(n, beta = beta_1, cens = cens)
} else if (scenario == "Scenario2") {
  simulate_data_scenario2(n, beta = beta_1, interactions = interactions, cens = cens)
} else if (scenario == "Scenario3") {
  simulate_data_scenario3(n, beta_1 = beta_1, beta_2 = beta_2, transition = transition, cens = cens)
} else if (scenario == "Scenario4") {
  simulate_data_scenario4(n, beta_1 = beta_1, beta_2 = beta_2,
                          interactions = interactions, transition = transition, cens = cens)
}



# Main effects of the true model: x1-x5, X1-X5 and the treatment indicator.
main_effects <- c(paste0("x", 1:5), paste0("X", 1:5), "trt")

# Covariates offered to the main-effects-only models: the main effects plus
# z1-z5, which are not part of the true model (cov_true). Same for all scenarios.
covariates <- switch(as.character(scenario),
                     "Scenario1" = ,
                     "Scenario2" = ,
                     "Scenario3" = ,
                     "Scenario4" = c(main_effects, paste0("z", 1:5)))

# Covariates of the true model; Scenarios 2 and 4 add the interaction terms.
cov_true <- switch(as.character(scenario),
                   "Scenario1" = ,
                   "Scenario3" = main_effects,
                   "Scenario2" = ,
                   "Scenario4" = c(main_effects, interactions))



# Pseudo-values of the RMST restricted to tau (pseudo::pseudomean) on the full
# data set; they are the response of the GEE and forest models below.
rmst_pseudo_values <- pseudomean(time = sim.data$time, event = sim.data$status, tmax = tau)
sim.data <- sim.data %>% mutate(pseudo = rmst_pseudo_values)

summary(sim.data$pseudo)


## Covariate data according to the true data-generating process: the true main
## effects and, in Scenarios 2 and 4, one product column per interaction term
## (named with "_" instead of the formula's "*", e.g. x1_x3 = x1 * x3).
with_true_interactions <- function(dat) {
  dat %>%
    mutate(x1_x3 = x1 * x3,
           x1_x4 = x1 * x4,
           x2_x3 = x2 * x3,
           x2_x5 = x2 * x5,
           x4_x5 = x4 * x5,
           x1_X2 = x1 * X2,
           x2_X3 = x2 * X3,
           x3_X4 = x3 * X4)
}

sim.data.aug <- switch(as.character(scenario),
                       "Scenario0" = ,
                       "Scenario1" = ,
                       "Scenario3" = sim.data %>% dplyr::select(all_of(cov_true)),
                       "Scenario2" = ,
                       "Scenario4" = sim.data %>%
                         dplyr::select(all_of(cov_true[!(cov_true %in% interactions)])) %>%
                         with_true_interactions())

## The covariate data for the main-effects-only models (`sim.data.aug.noise`)
## is built in the "Cox model (main effects only)" section, directly before
## its first use. A scenario-specific version is deliberately not computed
## here: it would be overwritten there before ever being read.


# Theoretical (true) RMST of every individual under the data-generating model,
# for the observed treatment and for the counterfactuals trt = 0 and trt = 1.
true_rmst <- function(covariate_data) {
  sapply(1:n, FUN = get_RMST, transition = transition, tau = tau,
         beta_1 = beta_1, beta_2 = beta_2, shape = 1, scale = 1,
         data = covariate_data)
}

RMST      <- true_rmst(sim.data.aug)
RMST_trt0 <- true_rmst(sim.data.aug %>% mutate(trt = 0))
RMST_trt1 <- true_rmst(sim.data.aug %>% mutate(trt = 1))

# individual true treatment effect on the RMST scale
RMST_diff <- RMST_trt1 - RMST_trt0

############### Kaplan Meier ###############

print("Kaplan Meier...")

# Arm-wise RMST up to tau from the Kaplan-Meier curves (survRM2::rmst2).
km     <- rmst2(time = sim.data$time, status = sim.data$status, arm = sim.data$trt, tau = tau)
km_fit <- survfit(Surv(time, status) ~ trt, data = sim.data)

# Kaplan-Meier survival probabilities at tau per arm (not used further below).
S_dach_tmp <- summary(km_fit, times = tau)$surv

# first entry of the arm's `rmst` summary, stripped of its name
arm_estimate <- function(arm) as.numeric(arm$rmst[1])
RMST_trt0_km <- arm_estimate(km$RMST.arm0)
RMST_trt1_km <- arm_estimate(km$RMST.arm1)

# Each individual gets the estimate of his/her own arm; the estimated
# treatment effect is therefore the same for everybody.
RMST_km      <- ifelse(sim.data$trt != 0, RMST_trt1_km, RMST_trt0_km)
RMST_diff_km <- rep_len(RMST_trt1_km - RMST_trt0_km, n)


############### Cox (Reference model) ###############

print("Cox model (Reference)...")

# Per-individual RMST predictions of a fitted Cox model on the true-model
# covariate data (sim.data.aug): for the observed treatment and for the
# counterfactuals trt = 0 and trt = 1. `rmst_fun` is called as
# rmst_fun(<row index>, data = <covariate data>, ...).
cox_rmst_by_trt <- function(rmst_fun, ...) {
  list(
    obs  = sapply(1:n, FUN = rmst_fun, data = sim.data.aug, ...),
    trt0 = sapply(1:n, FUN = rmst_fun, data = sim.data.aug %>% mutate(trt = 0), ...),
    trt1 = sapply(1:n, FUN = rmst_fun, data = sim.data.aug %>% mutate(trt = 1), ...)
  )
}

if (scenario %in% c("Scenario1", "Scenario2")) {

  # Proportional hazards model with the true covariates (including the
  # interaction terms in Scenario 2).
  cox_model <- coxph(as.formula(paste0("Surv(time, status) ~ ", paste0(cov_true, collapse = " + "))),
                     data = sim.data, x = TRUE)
  coef_cox  <- coef(cox_model)
  cox_rmst  <- cox_rmst_by_trt(get_RMST_cox, tau = tau, beta = coef_cox, model = cox_model)

  # estimated log hazard ratio of treatment
  trt_effect_cox <- as.numeric(coef_cox[names(coef_cox) == "trt"])

} else if (scenario %in% c("Scenario3", "Scenario4")) {

  # Split every follow-up at the transition time into episodes `gr`, so that
  # the treatment effect can differ between the two periods. The tiny negative
  # origin plus the reset of negative start times to 0 presumably keeps
  # observations with time 0 in the data - TODO confirm.
  sim.data.aug.cox <- survSplit(Surv(time, status) ~., data = sim.data, cut = c(transition),
                                episode = "gr", id = "newid", zero = -0.000000000001) %>%
    mutate(tstart = ifelse(tstart < 0, 0, tstart))

  cox_model <- coxph(as.formula(paste0("Surv(tstart, time, status) ~ ", paste0(c(cov_true, "trt:strata(gr)"), collapse = " + "))),
                     data = sim.data.aug.cox, x = TRUE)
  coef_cox  <- coef(cox_model)

  # One coefficient per entry of cov_true (11 in Scenario 3, 19 in Scenario 4),
  # followed by the trt x episode interaction. The coefficients after the
  # transition differ from those before only in the treatment coefficient.
  n_main     <- length(cov_true)
  coef_cox_1 <- coef_cox[seq_len(n_main)]
  coef_cox_2 <- coef_cox_1
  trt_pos    <- match("trt", names(coef_cox_1))
  coef_cox_2[trt_pos] <- coef_cox_2[trt_pos] + coef_cox[n_main + 1]

  cox_rmst <- cox_rmst_by_trt(get_RMST_cox_transition, beta_1 = coef_cox_1, beta_2 = coef_cox_2,
                              model = cox_model, transition = transition, tau = tau)

} else {
  stop("No reference Cox model defined for scenario '", scenario, "'.", call. = FALSE)
}

RMST_cox      <- cox_rmst$obs
RMST_trt0_cox <- cox_rmst$trt0
RMST_trt1_cox <- cox_rmst$trt1
RMST_diff_cox <- RMST_trt1_cox - RMST_trt0_cox

############### Cox model (main effects only) ####################

# Covariate data for the main-effects-only ("false") Cox model: exactly the
# columns listed in `covariates`, in that order, i.e. the order of the model
# terms and hence of coef_cox_false. Selecting by name (rather than dropping
# id/time/status/pseudo) keeps the data aligned with the coefficients even if
# sim.data has additional columns or a different column order.
sim.data.aug.noise <- sim.data %>% dplyr::select(all_of(covariates))

cox_model_false <- coxph(as.formula(paste0("Surv(time, status) ~ ", paste0(covariates, collapse = " + "))),
                         data =  sim.data, x = TRUE)
coef_cox_false  <- coef(cox_model_false)

# Per-individual RMST of the false Cox model for the covariate data `newdata`.
cox_false_rmst <- function(newdata) {
  sapply(1:n, FUN = get_RMST_cox, tau = tau, beta = coef_cox_false,
         data = newdata, model = cox_model_false)
}

RMST_cox_false       <- cox_false_rmst(sim.data.aug.noise)
RMST_trt0_cox_false  <- cox_false_rmst(sim.data.aug.noise %>% mutate(trt = 0))
RMST_trt1_cox_false  <- cox_false_rmst(sim.data.aug.noise %>% mutate(trt = 1))
RMST_diff_cox_false  <- RMST_trt1_cox_false - RMST_trt0_cox_false

# Estimated log hazard ratio of treatment in the false model. It is kept
# separate from trt_effect_cox, the treatment effect of the reference Cox model
# (Scenarios 1 and 2), which must not be overwritten here.
trt_effect_cox_false <- as.numeric(coef_cox_false[names(coef_cox_false) == "trt"])

############### Lognormal ###############
print("Lognormal...")

# Lognormal survival regression on all covariates. The lognormal model needs
# strictly positive times, so observed times of exactly 0 are shifted to 0.001
# for fitting.
lognormal_model <- survreg(as.formula(paste("Surv(time, status) ~", paste(covariates, collapse = " + "))),
                           data = sim.data %>% mutate(time = ifelse(time == 0, time + 0.001, time)),
                           dist = "lognormal")
coef_lognormal  <- lognormal_model$coefficients

# RMST up to tau for each linear predictor, given the fitted scale.
lognormal_rmst <- function(lp) {
  sapply(lp, FUN = my_integral_lognormal, tau = tau, scale = lognormal_model$scale)
}

# Linear predictors for the observed data (names stripped) and for the
# counterfactuals trt = 0 / trt = 1.
pred_lognormal   <- as.numeric(predict(lognormal_model, newdata = sim.data, type = "lp"))
pred_lognormal_0 <- predict(lognormal_model, type = "lp", newdata = sim.data %>% mutate(trt = 0))
pred_lognormal_1 <- predict(lognormal_model, type = "lp", newdata = sim.data %>% mutate(trt = 1))

RMST_lognormal      <- lognormal_rmst(pred_lognormal)
RMST_trt0_lognormal <- lognormal_rmst(pred_lognormal_0)
RMST_trt1_lognormal <- lognormal_rmst(pred_lognormal_1)
RMST_diff_lognormal <- RMST_trt1_lognormal - RMST_trt0_lognormal

bias_lognormal <- mean(RMST - RMST_lognormal)


# Formula for the "incorrect" GEE and forest models: the pseudo-values
# regressed on the main effects of all covariates (including z1-z5).
# Plain `<-` is used: at top level it is equivalent to `<<-`, and when the
# script is sourced into a local environment it does not leak into outer scopes.
formula_models <- as.formula(paste0("pseudo ~ ", paste0(covariates, collapse = " + ")))

#### GEE ####
print("GEE...")

# Design matrix for turning GEE coefficients into predictions: an intercept
# column followed by the columns named in `covariates`, in that order - the
# term order of formula_models and hence the order of gee$beta. Columns are
# picked by name, not by position, so the result does not depend on the
# column layout of sim.data. If `trt` is given, the treatment column is
# overwritten to obtain counterfactual predictions.
gee_design <- function(data, trt = NULL) {
  design <- cbind(1, as.matrix(dplyr::select(data, all_of(covariates))))
  if (!is.null(trt)) {
    design[, "trt"] <- trt
  }
  design
}

design_obs  <- gee_design(sim.data)
design_trt0 <- gee_design(sim.data, trt = 0)
design_trt1 <- gee_design(sim.data, trt = 1)


# identity link: RMST = X %*% beta
gee_identity        <- geese(formula = formula_models, data = sim.data, 
                             id = id, jack = TRUE, family = gaussian, corstr = "independence", scale.fix = FALSE)
coef_gee_identity   <- gee_identity$beta
RMST_gee_identity   <- as.numeric(design_obs %*% gee_identity$beta)

RMST_trt0_gee   <- as.numeric(design_trt0 %*% gee_identity$beta)
RMST_trt1_gee   <- as.numeric(design_trt1 %*% gee_identity$beta)

RMST_diff_gee  <- RMST_trt1_gee - RMST_trt0_gee


# log link: RMST = exp(X %*% beta)
gee_log           <- geese(formula = formula_models, data = sim.data, id = id, 
                           jack = TRUE, family = gaussian, mean.link = "log", corstr = "independence", scale.fix = FALSE)
coef_gee_log      <- gee_log$beta
RMST_gee_log      <- exp(as.numeric(design_obs %*% gee_log$beta))

RMST_trt0_gee_log   <- exp(as.numeric(design_trt0 %*% gee_log$beta))
RMST_trt1_gee_log   <- exp(as.numeric(design_trt1 %*% gee_log$beta))

RMST_diff_gee_log  <- RMST_trt1_gee_log - RMST_trt0_gee_log


#### Random Forest ####
print("Random Forest...")
print("5-fold cv for mtry")

# mtry is tuned by 5-fold cross-validation of the pseudo-value predictions.
# Pseudo-values are recomputed separately within each training and test fold,
# so the test pseudo-values do not depend on the training observations.
mtry_vec <- seq_along(covariates)
folds    <- chunk(seq_len(nrow(sim.data)), n.chunks = 5)

# Cross-validated mean squared error for one value of mtry.
rf_cv_error <- function(mtry) {
  print(mtry)
  fold_mse <- vapply(1:5, function(k) {
    in_test <- folds == k
    d_train <- sim.data %>%
      slice(which(!in_test)) %>%
      mutate(pseudo = pseudomean(time = time, event = status, tmax = tau))
    d_test <- sim.data %>%
      slice(which(in_test)) %>%
      mutate(pseudo = pseudomean(time = time, event = status, tmax = tau))
    fit   <- ranger(formula = formula_models, data = d_train, replace = FALSE, mtry = mtry)
    preds <- predict(fit, data = d_test)$predictions
    mean((d_test$pseudo - preds)^2)
  }, numeric(1))
  mean(fold_mse)
}

error    <- vapply(mtry_vec, rf_cv_error, numeric(1))
mtry_opt <- mtry_vec[which.min(error)]

# final forest on the full data with the selected mtry
rforest      <- ranger(formula = formula_models, data = sim.data, replace = FALSE, mtry = mtry_opt)
RMST_rforest <- predict(rforest, data = sim.data)$predictions

RMST_trt0_rforest <- predict(rforest, data = sim.data %>% mutate(trt = 0))$predictions
RMST_trt1_rforest <- predict(rforest, data = sim.data %>% mutate(trt = 1))$predictions
RMST_diff_rforest <- RMST_trt1_rforest - RMST_trt0_rforest




#### Conditional random forest ####
print("Conditional random forest...")
print("5-fold cv for mtry")

# Tree settings shared by the cross-validation fits and the final fit.
cforest_control <- partykit::ctree_control(teststat = "maximum",
                                           splitstat = "maximum",
                                           testtype = "Univariate")

# mtry is tuned by 5-fold cross-validation of the pseudo-value predictions,
# with pseudo-values recomputed within each training and test fold.
mtry_vec <- seq_along(covariates)
folds    <- chunk(seq_len(nrow(sim.data)), n.chunks = 5)

# Cross-validated mean squared error for one value of mtry.
cforest_cv_error <- function(mtry) {
  print(mtry)
  fold_mse <- vapply(1:5, function(k) {
    in_test <- folds == k
    d_train <- sim.data %>%
      slice(which(!in_test)) %>%
      mutate(pseudo = pseudomean(time = time, event = status, tmax = tau))
    d_test <- sim.data %>%
      slice(which(in_test)) %>%
      mutate(pseudo = pseudomean(time = time, event = status, tmax = tau))
    fit   <- partykit::cforest(formula = formula_models, data = d_train,
                               control = cforest_control, mtry = mtry, cores = 1)
    preds <- predict(fit, newdata = d_test)
    mean((d_test$pseudo - preds)^2)
  }, numeric(1))
  mean(fold_mse)
}

error    <- vapply(mtry_vec, cforest_cv_error, numeric(1))
mtry_opt <- mtry_vec[which.min(error)]

# final conditional forest on the full data with the selected mtry
cforest <- partykit::cforest(formula = formula_models, data = sim.data,
                             control = cforest_control, mtry = mtry_opt, cores = 1)

RMST_cforest      <- predict(cforest, newdata = sim.data)
RMST_trt0_cforest <- predict(cforest, newdata = sim.data %>% mutate(trt = 0))
RMST_trt1_cforest <- predict(cforest, newdata = sim.data %>% mutate(trt = 1))
RMST_diff_cforest <- RMST_trt1_cforest - RMST_trt0_cforest

bias_cforest <- mean(RMST - RMST_cforest)


# aggregate results
# Per-individual RMST estimates of every method next to the true RMST
# (one row per individual). The km column holds the arm-wise Kaplan-Meier
# estimate, i.e. the same value for all individuals in an arm.
data_rmst <- cbind.data.frame(true = RMST, 
                              km = RMST_km, 
                              cox = RMST_cox,
                              lognormal = RMST_lognormal, 
                              gee = RMST_gee_identity,
                              gee_log = RMST_gee_log,
                              CART = RMST_rforest,
                              cforest = RMST_cforest,
                              cox_false = RMST_cox_false)


# Same for the counterfactual trt = 0. RMST_trt0_km is a single number and is
# recycled to every row by cbind.data.frame.
data_rmst_trt0 <- cbind.data.frame(true = RMST_trt0, 
                                   km = RMST_trt0_km, 
                                   cox = RMST_trt0_cox,
                                   lognormal = RMST_trt0_lognormal, 
                                   gee = RMST_trt0_gee,
                                   gee_log = RMST_trt0_gee_log,
                                   CART = RMST_trt0_rforest,
                                   cforest = RMST_trt0_cforest,
                                   cox_false = RMST_trt0_cox_false)

# Same for the counterfactual trt = 1 (RMST_trt1_km is recycled likewise).
data_rmst_trt1 <- cbind.data.frame(true = RMST_trt1, 
                                   km = RMST_trt1_km, 
                                   cox = RMST_trt1_cox,
                                   lognormal = RMST_trt1_lognormal,
                                   gee = RMST_trt1_gee,
                                   gee_log = RMST_trt1_gee_log,
                                   CART = RMST_trt1_rforest,
                                   cforest = RMST_trt1_cforest,
                                   cox_false = RMST_trt1_cox_false)


# Individual treatment effects (RMST under trt = 1 minus trt = 0) per method.
rmst_diff_by_method <- list(km        = RMST_diff_km,
                            cox       = RMST_diff_cox,
                            lognormal = RMST_diff_lognormal,
                            gee       = RMST_diff_gee,
                            gee_log   = RMST_diff_gee_log,
                            CART      = RMST_diff_rforest,
                            cforest   = RMST_diff_cforest,
                            cox_false = RMST_diff_cox_false)

# One-row summary: average treatment effect of the truth and of every method.
data_rmst_diff2 <- as.data.frame(c(list(true = mean(RMST_diff)),
                                   lapply(rmst_diff_by_method, mean)))

# One-row summary: mean squared error of the individual treatment effects
# with respect to the true ones (0 for the truth itself).
data_rmst_diff <- as.data.frame(c(list(true = 0),
                                  lapply(rmst_diff_by_method,
                                         function(est) mean((RMST_diff - est)^2))))



# Root mean squared error of the individual RMST estimates of every method
# with respect to the true RMST (one row, columns rmse_<method>).
rmst_estimates <- list(km        = RMST_km,
                       cox       = RMST_cox,
                       lognormal = RMST_lognormal,
                       GEE       = RMST_gee_identity,
                       GEE_log   = RMST_gee_log,
                       CART      = RMST_rforest,
                       cond      = RMST_cforest,
                       cox_false = RMST_cox_false)

rmse_values <- lapply(rmst_estimates, function(est) sqrt(mean((est - RMST)^2)))
names(rmse_values) <- paste0("rmse_", names(rmse_values))
data_rmse <- as.data.frame(rmse_values)


